#!/usr/bin/env python3
# E14 — Shapiro-like Delay (v2)
# Present-act engine (stdlib only). Control is pure boolean/ordinal; neighbor-only; NO RNG in control.
# Readouts (fits, CI, mesh checks) are diagnostics-only.

import argparse, csv, hashlib, json, math, os, sys
from datetime import datetime, timezone
from typing import Dict, List, Tuple

# ---------------- utilities ----------------
def utc_timestamp() -> str:
    """Return the current UTC time as ``YYYY-MM-DDTHH-MM-SSZ``.

    Colons are replaced by dashes so the stamp is safe to embed in file names.
    """
    now_utc = datetime.now(tz=timezone.utc)
    stamp = now_utc.strftime("%Y-%m-%dT%H-%M-%SZ")
    return stamp

def ensure_dirs(root: str, subs: List[str]) -> None:
    """Create ``root/<sub>`` for every entry in ``subs``, including missing parents.

    Directories that already exist are left untouched (``exist_ok=True``).
    """
    targets = (os.path.join(root, sub) for sub in subs)
    for target in targets:
        os.makedirs(target, exist_ok=True)

def write_text(path: str, txt: str) -> None:
    """Overwrite the file at ``path`` with ``txt``, opened in UTF-8 text mode."""
    with open(path, mode="w", encoding="utf-8") as handle:
        handle.write(txt)

def json_dump(path: str, obj: dict) -> None:
    """Write ``obj`` to ``path`` as UTF-8 JSON with 2-space indent and sorted keys.

    Sorted keys make the output deterministic, which matters because the saved
    manifest is later hashed.
    """
    with open(path, mode="w", encoding="utf-8") as handle:
        json.dump(obj, handle, sort_keys=True, indent=2)

def sha256_file(path: str) -> str:
    """Return the hex SHA-256 digest of the file at ``path``.

    The file is streamed in 1 MiB chunks so large files are not loaded whole.
    """
    digest = hashlib.sha256()
    with open(path, "rb") as stream:
        while True:
            block = stream.read(1 << 20)
            if not block:
                break
            digest.update(block)
    return digest.hexdigest()

def isqrt(n: int) -> int:
    """Return floor(sqrt(n)) for a non-negative integer ``n``.

    Uses exact integer arithmetic (no floating point), so the radius
    used by the control schedule stays integer/ordinal.
    """
    root = math.isqrt(n)
    return root

# ---------------- control schedule ----------------
def P_from_zones(r: int, zones: List[dict]) -> int:
    """Return the period of the first zone whose inclusive [r_min, r_max] contains ``r``.

    Zones are scanned in list order, so an earlier zone wins where zones
    overlap. A radius covered by no zone gets period 1.
    """
    matching_periods = (
        int(zone["period"])
        for zone in zones
        if int(zone["r_min"]) <= r <= int(zone["r_max"])
    )
    return next(matching_periods, 1)

def P_discrete_1_over_r(r: int, kappa: int) -> int:
    """Integer 1/r period law: ``P(r) = 1 + kappa // max(1, int(r))``.

    Radii below 1 are clamped to 1, so the centre cell shares the r=1 period.
    For kappa >= 0, the period grows inward and drops to 1 once r > kappa.
    """
    clamped_r = max(1, int(r))
    slowdown = kappa // clamped_r
    return 1 + slowdown

def period_at_radius(r: int, schedule: dict) -> int:
    """Return the integer period at radius ``r`` under ``schedule``.

    ``mode == "zones"`` uses the explicit zone table in ``schedule["zones"]``.
    Any other mode, including the default "discrete_1_over_r", uses the 1/r
    law with ``schedule["kappa"]`` (default 24).
    """
    if schedule.get("mode", "discrete_1_over_r") != "zones":
        kappa = int(schedule.get("kappa", 24))
        return P_discrete_1_over_r(r, kappa)
    return P_from_zones(r, schedule["zones"])

# ---------------- travel-time simulator ----------------
def travel_time_single_ray(N:int, cx:int, cy:int, y:int, x0:int, x1:int, schedule:dict,
                           guard:int=0) -> int:
    """Count the ticks a deterministic eastward ray needs to go from x0 to x1 along row y.

    At every tick t the ray, sitting at column x, takes r = isqrt of the
    squared distance from (x, y) to (cx, cy), looks up the integer period
    P(r) from ``schedule``, and moves +1 cell iff ``t % P == 0``. The tick
    counter advances every tick whether or not the ray moves.

    Args:
        N: grid size (not used by the stepping rule itself).
        cx, cy: centre the radius is measured from.
        y: fixed row of the ray.
        x0, x1: start column and stop column; the ray stops once x reaches x1.
        schedule: period schedule, passed to ``period_at_radius``.
        guard: tick budget. Values <= 0 select an automatic bound derived
            from the largest period the schedule can produce.

    Returns:
        Total ticks taken to reach x1 (0 if x1 <= x0).

    Raises:
        RuntimeError("GUARD_EXCEEDED"): if the tick budget runs out first.
    """
    if guard <= 0:
        if schedule.get("mode","discrete_1_over_r") == "zones":
            # Radii outside every zone run at period 1, so an empty zone list is
            # a valid flat schedule; default=1 keeps max() from raising on it.
            pmax = max((int(z["period"]) for z in schedule["zones"]), default=1)
        else:
            pmax = 1 + int(schedule.get("kappa", 24))
        # Assuming every period is >= 1, a cell costs at most pmax ticks (wait for
        # t % P == 0, then move), so pmax + 1 ticks per cell is a safe budget.
        guard = (x1 - x0) * max(2, pmax + 1)

    dy2 = (y - cy) * (y - cy)  # the row is fixed, so this is loop-invariant
    t, x = 0, x0
    while x < x1 and t < guard:
        dx = x - cx
        r = isqrt(dx*dx + dy2)
        P = period_at_radius(r, schedule)  # integer predicate only
        if (t % P) == 0:
            x += 1
        t += 1
    if x < x1:
        raise RuntimeError("GUARD_EXCEEDED")
    return t  # total ticks to traverse

def baseline_time_length(x0:int, x1:int) -> int:
    """Ticks to cross from x0 to x1 with no slowing.

    With P(r) = 1 everywhere the ray moves one cell per tick, so the time
    is just the column span.
    """
    span = x1 - x0
    return span

# ---------------- diagnostics ----------------
def linreg_y_on_x(xs: List[float], ys: List[float]) -> Tuple[float, float, float]:
    """Ordinary least-squares fit of ``y = a + b*x``.

    Returns ``(slope b, intercept a, R^2)``. All three are NaN when there are
    fewer than two points or the xs have zero variance. When the ys are
    constant (zero total sum of squares), R^2 is reported as 1.0.
    """
    nan = float("nan")
    count = len(xs)
    if count < 2:
        return nan, nan, nan

    x_mean = sum(xs) / count
    y_mean = sum(ys) / count
    pairs = list(zip(xs, ys))

    sxy = sum((x - x_mean) * (y - y_mean) for x, y in pairs)
    sxx = sum((x - x_mean) * (x - x_mean) for x in xs)
    if sxx == 0:
        return nan, nan, nan

    slope = sxy / sxx
    intercept = y_mean - slope * x_mean

    def _square(v):
        return v * v

    sst = sum(_square(y - y_mean) for y in ys)
    sse = sum(_square(y - (intercept + slope * x)) for x, y in pairs)
    r2 = 1.0 - sse / sst if sst > 0 else 1.0
    return slope, intercept, r2

def mesh_compare(a_coarse: float, a_fine: float) -> Tuple[float, float]:
    """Return ``(absolute, relative)`` difference between coarse and fine amplitudes.

    The relative difference divides by the mean magnitude of the two values.
    That denominator is floored at 1e-12, so two zero amplitudes give 0
    instead of a division by zero.
    """
    gap = abs(a_coarse - a_fine)
    scale = 0.5 * (abs(a_coarse) + abs(a_fine))
    return gap, gap / max(1e-12, scale)

# ---------------- run panel ----------------
def run_panel(manifest: dict, gridN:int, outdir:str, tag:str) -> dict:
    """Trace every configured ray on an N x N grid, log per-ray delays, and fit the log law.

    For each impact parameter b (in shells), a ray is traced along row
    ``cy + b``. If that row is off-grid, ``cy - b`` is used instead. Rays
    with neither row on-grid are skipped. Each ray's delay is ``t_on - t_off``
    against the unslowed baseline. Delays are then fitted as
    ``delay = intercept + slope * log(L_ref / b)``.

    Args:
        manifest: parsed manifest; reads the "grid", "source", "schedule",
            "fit" and "acceptance" sections.
        gridN: grid size N for this panel.
        outdir: directory that receives ``e14_<tag>_per_ray.csv``.
        tag: panel label used in the CSV file name.

    Returns:
        dict with the CSV path, per-ray ``b`` / ``x`` (= log(L_ref/b)) / ``y``
        (= delay) lists in manifest order (skipped rays omitted), the fit
        (``slope``, ``intercept``, ``r2``) and ``mono_ok``.

    NOTE: b and L_ref must be positive for log(L_ref / b) to be defined;
    b == 0 raises ZeroDivisionError.
    """
    N = gridN
    cx = int(manifest["grid"].get("cx", N//2))
    cy = int(manifest["grid"].get("cy", N//2))
    x_margin = int(manifest["source"].get("x_margin", 16))
    x0, x1 = x_margin, N - x_margin

    # impact parameters restricted to active slow-zone to avoid zero-delay plateau
    b_list = [int(b) for b in manifest["source"]["impact_params_shells"]]
    schedule = manifest["schedule"]

    # reference L for log(L/b). Keep constant across meshes for apples-to-apples slope.
    L_ref = int(manifest["fit"].get("L_ref_shells", 128))

    rows = []
    for b in b_list:
        y = cy + b
        if not (0 <= y < N):
            # Radius depends only on dy^2, so the mirrored row gives the same ray.
            y = cy - b
            if not (0 <= y < N):
                continue
        t_on  = travel_time_single_ray(N, cx, cy, y, x0, x1, schedule)
        t_off = baseline_time_length(x0, x1)
        dly = t_on - t_off
        rows.append((b, t_on, t_off, dly))

    # metrics csv
    mpath = os.path.join(outdir, f"e14_{tag}_per_ray.csv")
    with open(mpath, "w", newline="", encoding="utf-8") as f:
        w = csv.writer(f); w.writerow(["b_shells","t_on","t_off","delay"])
        for b, t_on, t_off, dly in rows: w.writerow([b, t_on, t_off, dly])

    # fit
    bs = [r[0] for r in rows]
    ys = [float(r[3]) for r in rows]
    xs = [math.log(L_ref / float(b)) for b in bs]
    slope, intercept, r2 = linreg_y_on_x(xs, ys)

    # Monotone inward check: as b decreases, the delay must not decrease, up to
    # a fractional tolerance. Order by b descending first so the verdict does
    # not depend on how the manifest lists impact parameters. The sort is
    # stable, so an already-descending list is checked exactly as before.
    tol = float(manifest["acceptance"].get("monotone_tol_frac", 0.02))
    inward = [d for _, d in sorted(zip(bs, ys), key=lambda pair: pair[0], reverse=True)]
    mono_ok = all(cur + tol*max(1.0, cur) >= prev
                  for prev, cur in zip(inward, inward[1:]))

    return {
        "csv": mpath,
        "b": bs, "x": xs, "y": ys,
        "slope": slope, "intercept": intercept, "r2": r2,
        "mono_ok": mono_ok
    }

# ---------------- main ----------------
def main():
    """Command-line entry point: run the coarse and fine E14 panels and write the audit.

    Reads the ``--manifest`` JSON and mirrors it into ``<outdir>/config``. It
    then writes a small environment log and runs one panel per mesh size
    (``grid.N`` and ``mesh.N_fine``). Next it applies the acceptance gates:
    amplitude band, R^2 floor, coarse/fine mesh agreement and monotone inward
    delay. Finally it writes the JSON audit and a one-line summary, which is
    also printed to stdout.
    """
    cli = argparse.ArgumentParser()
    cli.add_argument("--manifest", required=True)
    cli.add_argument("--outdir", required=True)
    opts = cli.parse_args()

    out_root = os.path.abspath(opts.outdir)
    ensure_dirs(out_root, ["config", "outputs/metrics", "outputs/audits",
                           "outputs/run_info", "outputs/mesh", "logs"])

    with open(opts.manifest, "r", encoding="utf-8") as src:
        manifest = json.load(src)
    # Keep a normalized copy beside the outputs; its hash is recorded in the audit.
    saved_manifest = os.path.join(out_root, "config", "manifest_e14.json")
    json_dump(saved_manifest, manifest)

    env_fields = [
        f"utc={utc_timestamp()}",
        f"os={os.name}",
        f"cwd={os.getcwd()}",
        f"python={sys.version.split()[0]}",
    ]
    write_text(os.path.join(out_root, "logs", "env.txt"), "\n".join(env_fields))

    # Same manifest, two mesh sizes; both panels write their CSVs to outputs/metrics.
    metrics_dir = os.path.join(out_root, "outputs/metrics")
    n_coarse = int(manifest["grid"]["N"])
    n_fine = int(manifest["mesh"].get("N_fine", 192))
    coarse = run_panel(manifest, n_coarse, metrics_dir, "coarse")
    fine = run_panel(manifest, n_fine, metrics_dir, "fine")

    # Acceptance thresholds.
    acc_cfg = manifest["acceptance"]
    r2_min = float(acc_cfg.get("r2_min", 0.96))
    amp_band = [float(v) for v in acc_cfg.get("amp_band_ticks", [0.0, 1e9])]
    amp_min, amp_max = amp_band
    mesh_cfg = manifest["mesh"]
    dabs_max = float(mesh_cfg.get("delta_amp_abs_max", 0.10))
    drel_max = float(mesh_cfg.get("delta_amp_rel_max", 0.15))

    def amp_in_band(panel):
        return amp_min <= panel["slope"] <= amp_max

    def r2_passes(panel):
        return panel["r2"] >= r2_min

    mono_ok = bool(coarse["mono_ok"] and fine["mono_ok"])
    dabs, drel = mesh_compare(coarse["slope"], fine["slope"])
    mesh_ok = (dabs <= dabs_max) and (drel <= drel_max)
    passed = all([
        amp_in_band(coarse), amp_in_band(fine),
        r2_passes(coarse), r2_passes(fine),
        mesh_ok, mono_ok,
    ])

    summary_keys = ["csv", "slope", "intercept", "r2", "mono_ok"]
    audit = {
        "sim": "E14_shapiro_delay_v2",
        "coarse": {key: coarse[key] for key in summary_keys},
        "fine": {key: fine[key] for key in summary_keys},
        "mesh": {"delta_amp_abs": dabs, "delta_amp_rel": drel, "ok": mesh_ok},
        "accept": {
            "r2_min": r2_min,
            "amp_band_ticks": [amp_min, amp_max],
            "delta_amp_abs_max": dabs_max,
            "delta_amp_rel_max": drel_max,
        },
        "pass": passed,
        "manifest_hash": sha256_file(saved_manifest),
    }
    json_dump(os.path.join(out_root, "outputs/audits", "e14_audit.json"), audit)

    template = ("E14_v2 PASS={p} A_hat_coarse={ac:.6f} R2_c={r2c:.4f} "
                "A_hat_fine={af:.6f} R2_f={r2f:.4f} ΔA={da:.6f} relΔ={dr:.3f}")
    result_line = template.format(p=passed, ac=coarse["slope"], r2c=coarse["r2"],
                                  af=fine["slope"], r2f=fine["r2"],
                                  da=dabs, dr=drel)
    write_text(os.path.join(out_root, "outputs/run_info", "result_line.txt"), result_line)
    print(result_line)

if __name__ == "__main__":
    # Run the E14 pipeline only when executed as a script, not when imported.
    main()
